{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch_geometric.utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-09T14:27:04.158229Z",
     "start_time": "2021-06-09T14:27:02.685113Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## degree\n",
    "根据所给索引，计算节点的度\n",
    "### 参数\n",
    "+ index (LongTensor)  - 边连接的节点的索引\n",
    "+ num_nodes (int, optional) - 节点的数量\n",
    "+ dtype(torch.dtype, optional) - 需要的返回类型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 样例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-09T12:07:38.143008Z",
     "start_time": "2021-06-09T12:07:38.127439Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2, 1, 0],\n",
      "        [3, 2, 3]])\n",
      "头节点: tensor([2, 1, 0])\n",
      "尾节点: tensor([3, 2, 3])\n",
      "出度: tensor([1., 1., 1.])\n",
      "入度: tensor([0., 0., 1., 2.])\n",
      "出度: tensor([1., 1., 1., 0.])\n",
      "入度: tensor([0., 0., 1., 2.])\n"
     ]
    }
   ],
   "source": [
    "a = torch.LongTensor([[2,3],[1,2],[0,3]]).t()\n",
    "print(a)\n",
    "headindex, tailindex = a\n",
    "print(\"头节点:\",headindex)\n",
    "print(\"尾节点:\",tailindex)\n",
    "print(\"出度:\",degree(headindex))\n",
    "# tensor([1., 1., 1.])\n",
    "print(\"入度:\",degree(tailindex))\n",
    "# tensor([0., 0., 1., 2.])\n",
    "\n",
    "# 这里可以看到，出度和入度输出的维度不同，因为节点3不作为头节点，所以默认只有3个节点，如果要得到一致维度的输出，要额外传入节点个数\n",
    "print(\"出度:\",degree(headindex, num_nodes = 4))\n",
    "# tensor([1., 1., 1., 0., 0.])\n",
    "print(\"入度:\",degree(tailindex, num_nodes = 4))\n",
    "# tensor([0., 0., 1., 2., 0.])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## add_self_loops\n",
    "给图添加自环\n",
    "### 参数\n",
    "+ edge_index (LongTensor) – 边索引.\n",
    "+ edge_weight (Tensor, optional) – 边权重(default: None)\n",
    "+ fill_value (float, optional) – 如果边权重为非空，则会将fill_value填给自环，作为自环的权重 (default:1)\n",
    "+ num_nodes (int, optional) – 节点的数量 (default: None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "### 样例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-09T12:15:50.516457Z",
     "start_time": "2021-06-09T12:15:50.504420Z"
    },
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "无权重: (tensor([[2, 1, 0, 0, 1, 2, 3],\n",
      "        [3, 2, 3, 0, 1, 2, 3]]), None)\n",
      "有权重: (tensor([[2, 1, 0, 0, 1, 2, 3],\n",
      "        [3, 2, 3, 0, 1, 2, 3]]), tensor([0.5000, 0.5000, 0.5000, 0.9000, 0.9000, 0.9000, 0.9000]))\n"
     ]
    }
   ],
   "source": [
    "a = torch.LongTensor([[2,3],[1,2],[0,3]]).t()\n",
    "w = torch.FloatTensor([0.5, 0.5, 0.5])\n",
    "\n",
    "print(\"无权重:\",add_self_loops(a))\n",
    "# (tensor([[2, 1, 0, 0, 1, 2, 3],\n",
    "#         [3, 2, 3, 0, 1, 2, 3]]), None)\n",
    "print(\"有权重:\",add_self_loops(a,w,0.9))\n",
    "# (tensor([[2, 1, 0, 0, 1, 2, 3],\n",
    "#         [3, 2, 3, 0, 1, 2, 3]]), tensor([0.5000, 0.5000, 0.5000, 0.9000, 0.9000, 0.9000, 0.9000]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## add_remaining_self_loops\n",
    "防止图是加权的且已有部分自环，用本函数添加剩余自环，比add_self_loops更精确\n",
    "### 参数\n",
    "+ edge_index (LongTensor) – 边索引.\n",
    "+ edge_weight (Tensor, optional) – 边权重(default: None)\n",
    "+ fill_value (float, optional) – 如果边权重为非空，则会将fill_value填给自环，作为自环的权重 (default:1)\n",
    "+ num_nodes (int, optional) – 节点的数量 (default: None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 样例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-09T14:28:51.083556Z",
     "start_time": "2021-06-09T14:28:51.072654Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "添加自环: (tensor([[2, 1, 0, 1, 0, 1, 2, 3],\n",
      "        [3, 2, 3, 1, 0, 1, 2, 3]]), None)\n",
      "添加剩余自环: (tensor([[2, 1, 0, 0, 1, 2, 3],\n",
      "        [3, 2, 3, 0, 1, 2, 3]]), None)\n"
     ]
    }
   ],
   "source": [
    "a = torch.LongTensor([[2,3],[1,2],[0,3],[1,1]]).t()\n",
    "w = torch.FloatTensor([0.5, 0.5, 0.5])\n",
    "\n",
    "# 比较两种添加自环方式，add_self_loops返回了两次(1,1)\n",
    "print(\"添加自环:\",add_self_loops(a))\n",
    "print(\"添加剩余自环:\",add_remaining_self_loops(a))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## num_nodes.maybe_num_nodes\n",
    "返回节点数\n",
    "### 参数\n",
    "+ edge_index - 边索引，如果边索引为Tensor类型\n",
    "+ num_nodes - 节点数，默认为None  \n",
    "\n",
    "### 代码   \n",
    "```python\n",
    "def maybe_num_nodes(edge_index, num_nodes=None):  \n",
    "    if num_nodes is not None:  \n",
    "        return num_nodes  \n",
    "    elif isinstance(edge_index, Tensor):  \n",
    "        return int(edge_index.max()) + 1  \n",
    "    else:  \n",
    "        return max(edge_index.size(0), edge_index.size(1))\n",
    "```\n",
    "如果传入节点数，直接返回节点数  \n",
    "如果edge_index是Tensor格式，返回最大值+1，否则返回较大维度值"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "song",
   "language": "python",
   "name": "song"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
